using LinearAlgebra
using SparseArrays


# Physical constants (SI units)
const hbar = 1.05457e-34          # reduced Planck constant [J s]
const amu = 1.660539e-27          # atomic mass unit [kg]
const eps_0 = 8.854187e-12        # vacuum permittivity [F/m]
const bohr_radius = 5.291772e-11  # Bohr radius [m]
const debye = 3.33564e-30         # one debye [C m]
const k_boltzmann = 1.380649e-23  # Boltzmann constant [J/K]

# NaCs molecule parameters
const d_NaCs = 4.6 * debye        # dipole moment, 4.6 D [C m]
const m_NaCs = 156 * amu          # mass, 156 amu [kg]

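# Basis ordering for multi-dimensional operators:
#   x ⊗ y ⊗ z
# Kronecker products are taken in this order, so in the flattened basis the
# z index varies fastest and the x index slowest.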


# turn a multi-dimensional potential into a Hamiltonian with corresponding appropriate ordering of indices according to kron
function potential_to_hamiltonian(V_xyz::Array{Float64})::SparseMatrixCSC{Float64,Int64}
    dim_arr = size(V_xyz)
    potential_vector = zeros(prod(dim_arr))
    for idx_arr in Iterators.product(range.(1, dim_arr)...)
        basis_vec_arr = [sparsevec([aa], [1.], dim_aa) for (aa, dim_aa) in Iterators.zip(idx_arr, dim_arr)]
        basis_vec = kron(basis_vec_arr...)
        xyz_idx = basis_vec.nzind[1]
        potential_vector[xyz_idx] = V_xyz[idx_arr...]
    end

    return spdiagm(0 => potential_vector)
end


# DVR calculation: the total kinetic Hamiltonian is the sum, over dimensions, of each 1-D kinetic Hamiltonian Kronecker-multiplied with identities on the remaining dimensions
"""
    _unscaled_kinetic_hamiltonian(dim::Int64)

Return the dimensionless `dim × dim` DVR kinetic-energy matrix (Eq. (2.1) by
Miller/Colbert), to be multiplied by the energy scale `ħ² / (2 m Δx²)`:

    K[i, i] = π² / 3
    K[i, j] = 2 (-1)^(i - j) / (i - j)²      for i ≠ j
"""
function _unscaled_kinetic_hamiltonian(dim::Int64)::Matrix{Float64}
    return [row == col ? pi^2 / 3 : 2 * (-1)^(row - col) / (row - col)^2
            for row in 1:dim, col in 1:dim]
end


"""
    _prefactor(mass, step_size)

Energy scale `ħ² / (2 m Δx²)` that multiplies the dimensionless DVR kinetic matrix.
"""
_prefactor(mass::Float64, step_size::Float64)::Float64 = hbar^2 / (2 * mass * step_size^2)


"""
    kinetic_hamiltonian(mass, dim_arr, step_size_arr)

Assemble the multi-dimensional DVR kinetic-energy operator

    T = Σ_d  I_1 ⊗ … ⊗ T_d ⊗ … ⊗ I_D

where `T_d` is the scaled 1-D kinetic matrix for a grid of `dim_arr[d]` points spaced
by `step_size_arr[d]`, and `I_k` is the identity on dimension `k`. The first dimension
is the outermost Kronecker factor. Any number of dimensions, including one, is
supported.
"""
function kinetic_hamiltonian(mass::Float64,
    dim_arr::Vector{Int64},
    step_size_arr::Vector{Float64}
)::SparseMatrixCSC{Float64,Int64}
    @assert length(dim_arr) == length(step_size_arr)
    num_dims = length(dim_arr)

    # Scaled 1-D kinetic matrices, one per dimension. They are stored sparsely so
    # that the Kronecker products below produce sparse results.
    ham_1d_arr = [sparse(_prefactor(mass, step) .* _unscaled_kinetic_hamiltonian(dim))
                  for (dim, step) in zip(dim_arr, step_size_arr)]
    # Sparse Float64 identity matrix for each dimension.
    id_arr = [sparse(1.0I, dim, dim) for dim in dim_arr]

    total_dim = prod(dim_arr)
    ham_total = spzeros(Float64, total_dim, total_dim)
    for aa in 1:num_dims
        factors = [id_arr[1:aa-1]; ham_1d_arr[aa:aa]; id_arr[aa+1:num_dims]]
        # `reduce(kron, ...)` also covers a single factor (1-D grid), for which
        # `kron` called with one argument has no method.
        ham_total += reduce(kron, factors)
    end

    return ham_total
end


# Full Hamiltonian
"""
    dvr_hamiltonian(mass, step_size_arr, potential)

Return the full DVR Hamiltonian `H = T + V` for a particle of mass `mass`. The grid has
spacing `step_size_arr[d]` along dimension `d`. `potential` holds the potential energy
at every grid point, and the grid shape is taken from `size(potential)`.
"""
function dvr_hamiltonian(mass::Float64,
    step_size_arr::Vector{Float64},
    potential::Array{Float64}
)::SparseMatrixCSC{Float64,Int64}
    @assert length(step_size_arr) == ndims(potential)

    grid_shape = collect(size(potential))
    potential_op = potential_to_hamiltonian(potential)
    kinetic_op = kinetic_hamiltonian(mass, grid_shape, step_size_arr)

    return kinetic_op + potential_op
end


# Test problem: harmonic-oscillator Hamiltonian (its spectrum is known analytically)
"""
    harmonic_hamiltonian(mass, trap_frequency_arr, step_size_arr, dim_arr)

Build the DVR Hamiltonian of a separable harmonic trap,
`V(r) = 1/2 m Σ_d (2π ν_d)^2 r_d^2`, on a grid centred on the origin.

# Arguments
- `mass`: particle mass
- `trap_frequency_arr`: trap frequency `ν_d` for each dimension (not angular: the
  factor `2π` is applied here)
- `step_size_arr`: grid spacing for each dimension
- `dim_arr`: number of grid points for each dimension
"""
function harmonic_hamiltonian(
    mass::Float64,
    trap_frequency_arr::Vector{Float64},
    step_size_arr::Vector{Float64},
    dim_arr::Vector{Int64}
)::SparseMatrixCSC{Float64, Int64}
    @assert length(step_size_arr) == length(trap_frequency_arr)
    @assert length(dim_arr) == length(trap_frequency_arr)

    # Symmetric grid of `dim` points with spacing `step`, centred on zero
    # (e.g. dim = 4 gives step .* (-1.5, -0.5, 0.5, 1.5)).
    spatial_grid_arr = step_size_arr .* [-(dim-1)/2:(dim-1)/2 for dim in dim_arr]

    # V = 1/2 m (2pi nu)^2 x^2, summed over dimensions.
    # (2πν)² does not depend on the grid point, so compute it once.
    angular_freq_sq = (2pi .* trap_frequency_arr) .^ 2
    potential = map(Iterators.product(spatial_grid_arr...)) do pos_vec
        1/2 * mass * sum(angular_freq_sq .* (pos_vec .^ 2))
    end

    return dvr_hamiltonian(mass, step_size_arr, potential)
end
